#!/usr/bin/python

# TODO:
# ! add option for padding
# - fix occasionally missing page numbers
# - treat large h-whitespace as separator
# - handle overlapping candidates
# - use cc distance statistics instead of character scale
# - page frame detection
# - read and use text image segmentation mask
# - pick up stragglers
# ? laplacian as well

import pdb
from pylab import *
import argparse,glob,os,os.path
from scipy.ndimage import filters,interpolation,morphology,measurements
from scipy import stats
from scipy.misc import imsave
from scipy.ndimage.filters import gaussian_filter,uniform_filter,maximum_filter,minimum_filter
import ocrolib
from ocrolib import psegutils,morph,improc,sl
import multiprocessing
from multiprocessing import Pool

# Command-line interface.  The parsed `args` namespace is read as a global
# by compute_segmentation() and process1() below.
parser = argparse.ArgumentParser()
parser.add_argument('-z','--zoom',type=float,default=0.5,help='zoom for page background estimation, smaller=faster')

# --debug/--show are floats because their value is also passed to ginput()
# as the display timeout; 0 disables the display.
parser.add_argument('--debug',type=float,default=0,help='show debug output')
parser.add_argument('--show',type=float,default=0,help='show the final output')

# limits
parser.add_argument('--minscale',type=float,default=12.0,help='minimum scale permitted')
parser.add_argument('--maxlines',type=float,default=300,help='maximum # lines permitted')

# scale parameters
parser.add_argument('--hscale',type=float,default=1.0,help='scaling of horizontal parameters')
parser.add_argument('--vscale',type=float,default=1.0,help='scaling of vertical parameters')
parser.add_argument('--threshold',type=float,default=0.03,help='baseline threshold')
parser.add_argument('--usegauss',action='store_true',help='use gaussian instead of uniform')

# column parameters
parser.add_argument('--cdebug',type=float,default=0,help='show debug output for columns')
parser.add_argument('--maxcols',type=int,default=2,help='maximum # columns')
parser.add_argument('--cmaxwidth',type=float,default=10,help='maximum column width (units=scale)')
parser.add_argument('--cminheight',type=float,default=20,help='minimum column height (units=scale)')

# output and execution parameters
parser.add_argument('-p','--pad',type=int,default=3,help='padding for extracted lines')
parser.add_argument('-e','--expand',type=int,default=1,help='expand mask for grayscale extraction')
parser.add_argument('-Q','--parallel',type=int,default=0,
                    help='number of worker processes; values below 2 run serially, '
                    'and --debug/--show force serial processing')
parser.add_argument('files',nargs='+',
                    help='page images to segment, or a single directory whose ????.png files are processed')
args = parser.parse_args()
files = args.files

# Debug display: switch pylab to interactive mode with a gray colormap.
if args.debug:
    ion(); gray()

def compute_segmentation(binary,scale):
    """Compute a text-line label image for a binarized page.

    Tuning parameters (column limits, hscale/vscale, threshold, usegauss)
    and the debug display timeouts are read from the module-level `args`.

    Parameters:
      binary -- page image; converted to an unsigned-byte array on entry.
                Presumably nonzero marks ink (process1 passes
                1-read_image_binary(...)) -- TODO confirm against ocrolib.
      scale  -- size estimate in pixels from which all filter sizes are
                derived (process1 obtains it from psegutils.estimate_scale).

    Returns:
      The array `llines*binary`: the line labels restricted to the nonzero
      pixels of `binary` after separator pixels have been removed from it.

    When args.debug is nonzero, intermediate images are displayed and each
    display waits in ginput(1,args.debug), i.e. for one mouse click or for
    the timeout given as args.debug.
    """
    binary = array(binary,'B')
    
    # start by computing columns
    cols = psegutils.compute_columns_morph(binary,scale,debug=args.cdebug,
                                           maxcols=args.maxcols,
                                           maxwidth=args.cmaxwidth,
                                           minheight=args.cminheight)
    seps = psegutils.compute_separators_morph(binary,scale)
    if args.debug:
        clf(); title("cols"); imshow(maximum(seps,0.7*cols+0.3*binary)); ginput(1,args.debug)
    # from here on separators count as column dividers, and separator
    # pixels are removed from the foreground
    cols = maximum(cols,seps)
    binary = minimum(binary,1-seps)
    
    # use gradient filtering to find baselines
    boxmap = psegutils.compute_boxmap(binary,scale)
    # keep only foreground pixels that fall inside the box map
    cleaned = boxmap*binary
    if args.debug:
        clf(); title("cleaned"); imshow(cleaned); ginput(1,args.debug)
    # In both branches order=(1,0) makes gaussian_filter take the first
    # derivative along axis 0 and plain smoothing along axis 1.
    if args.usegauss:
        # this uses Gaussians
        grad = gaussian_filter(1.0*cleaned,(args.hscale*0.3*scale,
                                            args.vscale*6*scale),order=(1,0))
    else:
        # this uses non-Gaussian oriented filters
        # (a narrower derivative-of-Gaussian followed by a box filter)
        grad = gaussian_filter(1.0*cleaned,(max(4,args.hscale*0.3*scale),
                                            args.vscale*scale),order=(1,0))
        grad = uniform_filter(grad,(args.hscale,args.vscale*6*scale))
    # keep only the magnitude of the negative gradient; norm_max presumably
    # rescales by the maximum -- TODO confirm in ocrolib.improc
    bottom = improc.norm_max((grad<0)*(-grad))
    if args.debug:
        clf(); title("filtered"); imshow(bottom); ginput(1,args.debug)

    # morphological postprocessing and local maxima to create lines
    bottom = maximum_filter(bottom,(2,2))
    # NOTE(review): the second window size here is 0; confirm how the
    # installed scipy maximum_filter treats a zero-size axis.
    spread = maximum_filter(bottom,(int(args.vscale*scale),0))
    # candidate pixels: above the relative threshold and equal to the
    # maximum of their filter window
    loc = (bottom>args.threshold*amax(bottom))*(bottom>=spread)
    # drop candidates that lie on column dividers / separators
    loc = minimum(loc,1-cols)

    # select the longest lines only
    # (presumably regions whose sl.dim1 measure is at least 2*scale, at most
    # 1000 of them -- TODO confirm against morph.select_regions)
    loc = morph.select_regions(loc,sl.dim1,min=2*scale,nbest=1000)

    # thicken it vertically upwards to pick up most bounding boxes
    # on the baseline
    locm = maximum_filter(loc,(int(scale),1),origin=(-int(scale/2),0))

    # label lines, propagate to bounding boxes
    if args.debug:
        clf(); title("boxmap"); imshow(0.7*boxmap+0.3*binary); ginput(1,args.debug)
    lloc,_ = morph.label(locm)
    # conflict=-1 is the marker value tested for below
    llines = maximum(lloc,morph.propagate_labels(boxmap,lloc,conflict=-1))

    # where there are conflicts, resolve using spread
    # (note: `spread` is reused here for the spread label image)
    spread = morph.spread_labels(lloc,maxdist=scale)
    llines = where(llines==-1,spread,llines)
    if args.debug:
        clf(); title("spread"); morph.showlabels(llines); ginput(1,args.debug)

    # label all the remaining pixels
    # (still-unlabeled positions take the spread label where binary is set)
    llines = where(llines!=0,llines,binary*spread)
    if args.debug:
        clf(); title("remaining"); morph.showlabels(llines); ginput(1,args.debug)

    # now restrict to the foreground pixels
    segmentation = llines*binary
    if args.debug>0:
        clf(); title("segmentation"); imshow(segmentation,cmap=cm.flag); ginput(1,args.debug)

    return segmentation

def process1(job):
    """Segment one page image into text lines and write the results to disk.

    Parameters:
      job -- a `(fname, i)` pair: the path of the page image and the job
             number, which is only used in the progress line printed at
             the end.

    Input files (base = fname without its extension):
      base.bin.png -- if present, read as the binary image; the gray image
                      is then base.nrm.png if that exists, else fname.
      otherwise fname itself is read as the binary image and the inverted
      binary image doubles as the gray image.

    Output files:
      base.pseg.png      -- 3-channel image: a constant 1, segmentation//256
                            and segmentation, each cast to unsigned byte.
      base/01XXXX.png    -- masked gray extract of each line, in reading
      base/01XXXX.bin.png   order; XXXX is the 1-based line number in hex.

    Returns None.  The page is skipped, with a message on stderr, when the
    estimated scale is below args.minscale or when more than args.maxlines
    lines are found.  Reads the module-level `args`.
    """
    # `sys` is not imported explicitly at the top of the file, so import it
    # here rather than depend on the pylab star import exporting it.
    import sys

    fname,i = job
    base,_ = os.path.splitext(fname)
    if os.path.exists(base+".bin.png"):
        if os.path.exists(base+".nrm.png"):
            gray = ocrolib.read_image_gray(base+".nrm.png")
        else:
            gray = ocrolib.read_image_gray(fname)
        binary = 1-ocrolib.read_image_binary(base+".bin.png")
    else:
        binary = 1-ocrolib.read_image_binary(fname)
        gray = binary
    if args.debug:
        clf(); imshow(binary); title("binary"); ginput(1,maximum(args.show,args.debug))

    scale = psegutils.estimate_scale(binary)
    if scale<args.minscale:
        sys.stderr.write("%s: scale (%g) less than --minscale; skipping\n"%(fname,scale))
        return

    segmentation = compute_segmentation(binary,scale)
    lines = psegutils.compute_lines(segmentation,scale)
    order = psegutils.reading_order([l.bounds for l in lines],debug=args.debug)
    lsort = psegutils.topsort(order)
    if args.show or args.debug:
        clf(); title("output"); psegutils.show_lines(gray,lines,lsort); ginput(1,maximum(args.show,args.debug))

    if not os.path.exists(base):
        os.mkdir(base)
    # Reorder into reading order.  The index is deliberately not called `i`:
    # under Python 2 a list-comprehension variable leaks into the function
    # scope and would overwrite the job number.
    lines = [lines[k] for k in lsort]
    if len(lines)>args.maxlines:
        sys.stderr.write("%s: too many lines (%d), more than --maxlines; skipping\n"%(fname,len(lines)))
        return

    pseg = transpose(array([ones(segmentation.shape,'B'),
                            array(segmentation//256,'B'),
                            array(segmentation,'B')]),[1,2,0])
    imsave("%s.pseg.png"%base,pseg)
    for lineno,l in enumerate(lines,1):
        # NOTE(review): the clf() calls look unrelated to imsave; they are
        # kept so the pylab figure state is handled exactly as before.
        clf()
        imsave("%s/01%04x.png"%(base,lineno),psegutils.extract_masked(gray,l,pad=args.pad,expand=args.expand))
        clf()
        imsave("%s/01%04x.bin.png"%(base,lineno),psegutils.extract_masked(1-binary,l,pad=args.pad))
    # progress line: job number, file name, estimated scale, number of lines
    print("%6d %s %4.1f %d"%(i,fname,scale,len(lines)))

# Interactive display only makes sense in this process, so debug/show
# force serial execution.
if args.debug>0 or args.show>0: args.parallel = 0

# A single directory argument stands for the page images inside it that
# have a four-character base name.
if len(args.files)==1 and os.path.isdir(args.files[0]):
    files = glob.glob(args.files[0]+"/????.png")
else:
    files = args.files

# process1 takes (filename, 1-based job number) pairs.
jobs = [(f,i+1) for i,f in enumerate(files)]

if args.parallel<2:
    # Serial execution; the file name is echoed only when --parallel is 0.
    for job in jobs:
        if args.parallel==0: print(job[0])
        process1(job)
else:
    pool = Pool(processes=args.parallel)
    try:
        # process1 returns nothing useful; map is used only to wait for
        # all jobs to finish.
        pool.map(process1,jobs)
    finally:
        # Shut the worker processes down instead of leaving them to
        # interpreter exit.
        pool.close()
        pool.join()
